#from datetimei import date, datetime
import matplotlib.pyplot as plt
#from torchdiffeq import odeint
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchsde
import torch.nn.functional as F
import torch.optim as optim
# from loguru import logger
# from scipy.integrate import odeint
from torch.autograd import grad
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import random_split
from sklearn.model_selection import train_test_split
import argparse
import csv
import random
import time
# Make float64 the default floating-point dtype for tensors and parameters created
# from here on (including nn.Linear weights), so the model matches the float64
# tensors the script builds from the .npy data.
torch.set_default_dtype(torch.float64)

class Net(nn.Module):
    """Fully connected network mapping an ``input_dim`` vector to an ``input_dim`` vector.

    Layout: ``Linear(input_dim, layers[0]) -> act -> Linear(layers[0], layers[1])
    -> act -> ... -> Linear(layers[-1], input_dim)``.

    Args:
        input_dim: width of both the network input and its output.
        layers: hidden-layer widths; must contain at least one entry.
        activation: activation module *class*, instantiated once after every
            hidden ``Linear`` layer. Defaults to ``nn.ReLU``.

    Raises:
        ValueError: if ``layers`` is empty.
    """

    def __init__(self, input_dim, layers, activation=nn.ReLU):
        super(Net, self).__init__()
        if len(layers) == 0:
            raise ValueError("layers must contain at least one hidden width")
        self.input_dim = input_dim
        self.layers = layers
        self.activation = activation
        self.net = self._build_net()

    def _build_net(self):
        """Assemble the Linear/activation stack as an ``nn.Sequential``.

        Module indices match the historical layout (``net.0``, ``net.2``, ...),
        so previously saved state_dicts keep loading.
        """
        widths = [self.input_dim, *self.layers]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend([nn.Linear(fan_in, fan_out), self.activation()])
        # Project from the LAST hidden width back to input_dim. Using the loop
        # index here was wrong: it is undefined for a single hidden layer and
        # points at the second-to-last width otherwise.
        modules.append(nn.Linear(self.layers[-1], self.input_dim))
        return nn.Sequential(*modules)

    def forward(self, x):
        """Apply the network to ``x``; the last dimension of ``x`` must equal ``input_dim``."""
        return self.net(x)
    
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs (plus CUDA's, when available).

    On a CUDA machine this also forces deterministic cuDNN kernels and turns
    off cuDNN autotuning, trading speed for run-to-run reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
        
def _load_prices(path, transpose=False):
    """Load a ``.npy`` array from ``path`` and return it as a float64 tensor.

    Args:
        path: location of the ``.npy`` file.
        transpose: transpose the array before converting it.
    """
    array = np.load(path)
    if transpose:
        array = array.T
    return torch.from_numpy(array).type(torch.float64)


def _evaluate(model, loader, loss_fn):
    """Return the sample-weighted mean of ``loss_fn`` over every batch in ``loader``.

    Puts ``model`` in eval mode and runs under ``torch.no_grad()``, so no
    gradients are recorded and no parameters change. Returns NaN for an empty
    loader.
    """
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            total += loss_fn(model(x), y).item() * x.shape[0]
            count += x.shape[0]
    return total / count if count else float('nan')


def _build_arg_parser():
    """Build the command-line parser for the training script."""
    parser = argparse.ArgumentParser(description='PI_ROPF')
    # NN hyperparameters
    parser.add_argument('--seed', type=int, default=123, help='random seed')
    parser.add_argument('--nHiddenUnit', type=int, default=50, help='number of hidden units per layer')
    # NOTE(review): --activation is parsed but not read anywhere in this script.
    parser.add_argument('--activation', type=str, default="RELU", help='activation_function')
    parser.add_argument('--optimizer', type=int, default=3, choices=(1, 2, 3),
                        help='GD algorithm: 1=Adam, 2=Adadelta, 3=SGD')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--batchsize', type=int, default=50, help='training batch size')
    # NOTE(review): not read anywhere in this script; split sizes are fixed at 80/10/10.
    parser.add_argument('--train_test_and_valid_split', type=float, default=.2)
    # NOTE(review): type=bool makes any non-empty string (even "False") True;
    # the value is not read anywhere in this script.
    parser.add_argument('--normalize', type=bool, default=False)
    parser.add_argument('--max_epochs', type=int, default=700, help='max training epochs')
    parser.add_argument('--max_patience', type=int, default=5,
                        help='epochs without validation improvement before stopping')
    parser.add_argument('--nSamples', type=int, default=10000,
                        help='total number of samples across the 80/10/10 train/valid/test files')
    parser.add_argument('--nLayer', type=int, default=2, help='number of hidden layers')
    parser.add_argument('--id', type=int, default=5,
                        help='run id used in the checkpoint file name and the results CSV')
    return parser


def main():
    """Train ``Net`` with validation-based early stopping, then report the test MSE.

    Loads the train/valid/test splits from ``portfolio_data/``. Keeps the
    checkpoint with the lowest validation MSE in ``best_Net_model/`` and
    reloads it for the test pass. Appends ``id, MSE_valid, MSE_test`` to
    ``Net_static_portfolio_results.csv``.
    """
    import os  # local import: only the entry point touches the filesystem layout

    args = vars(_build_arg_parser().parse_args())

    set_seed(args['seed'])
    batch_size = args['batchsize']
    num_assets = 50  # number of assets; also the model's input/output width

    # File names encode an 80/10/10 train/valid/test partition of nSamples.
    n_train = int(args['nSamples'] * .8)
    n_holdout = int(args['nSamples'] * .1)
    # Initial-price files are transposed on load. NOTE(review): presumably they
    # are stored as (num_assets, n) while targets are (n, num_assets); confirm
    # against the data generator.
    X_train = _load_prices(f"portfolio_data/asset_prices_training_{n_train}.npy", transpose=True)
    X_valid = _load_prices(f"portfolio_data/asset_prices_validation_{n_holdout}.npy", transpose=True)
    X_test = _load_prices(f"portfolio_data/asset_prices_test_{n_holdout}.npy", transpose=True)
    Y_train = _load_prices(f"portfolio_data/dyn_asset_prices_training_{n_train}.npy")
    Y_valid = _load_prices(f"portfolio_data/dyn_asset_prices_validation_{n_holdout}.npy")
    Y_test = _load_prices(f"portfolio_data/dyn_asset_prices_test_{n_holdout}.npy")

    train_data = TensorDataset(X_train, Y_train)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Validation and test are each evaluated as a single full-size batch.
    valid_data = TensorDataset(X_valid, Y_valid)
    valid_loader = DataLoader(valid_data, batch_size=len(valid_data), shuffle=False)
    test_data = TensorDataset(X_test, Y_test)
    test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=False)

    model = Net(num_assets, [args['nHiddenUnit']] * args['nLayer'])
    print(model)

    # argparse `choices` guarantees the key exists.
    optimizer_classes = {1: torch.optim.Adam, 2: torch.optim.Adadelta, 3: torch.optim.SGD}
    optimizer = optimizer_classes[args['optimizer']](model.parameters(), lr=args['lr'])
    loss_fn = nn.MSELoss()

    ckpt_path = f"best_Net_model/model_{args['id']}.pt"
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    # Early-stopping state
    patience = 0
    max_patience = args['max_patience']
    min_loss = float('inf')
    saved_best = False  # avoids loading a stale/missing checkpoint if nothing was saved this run

    for epoch in range(args['max_epochs']):
        model.train()  # _evaluate leaves the model in eval mode; switch back every epoch
        running, seen = 0.0, 0
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            running += loss.item() * x.shape[0]
            seen += x.shape[0]
        if seen:
            print(f'Epoch {epoch}, training loss: {running / seen}')

        # Validation only measures the model: it must not call backward()/step().
        valid_loss = _evaluate(model, valid_loader, loss_fn)
        print(f'Valid loss: {valid_loss}')
        if valid_loss < min_loss:
            torch.save(model.state_dict(), ckpt_path)
            min_loss = valid_loss
            saved_best = True
            patience = 0
        else:
            patience += 1
            if patience >= max_patience:
                break

    if saved_best:
        model.load_state_dict(torch.load(ckpt_path))
    test_loss = _evaluate(model, test_loader, loss_fn)
    print(f'Test loss: {test_loss}')

    record = {
        'id': [args['id']],
        'MSE_valid': [min_loss],
        'MSE_test': [test_loss],
    }
    df = pd.DataFrame(record)
    df.to_csv('Net_static_portfolio_results.csv', mode='a', header=False, index=False)


if __name__ == '__main__':
    main()